#include <stdio.h>
#include <math.h>
#include <iostream>
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <sys/types.h>
#include <sys/ipc.h>
#include <sys/shm.h>
#include <unistd.h>
#include <string>
#include <sys/sem.h>
#include <fcntl.h>
#include <signal.h>
#include <sstream>
//REF: https://stackoverflow.com/questions/2279052/increase-stack-size-in-linux-with-setrlimit/2279084#2279084
#include <sys/resource.h>
#include <cerrno>
#include <cmath>
#include <random>
#include <stdexcept>
#include <vector>
using namespace std;

/**
 * Thin wrapper around one System V shared-memory segment.
 *
 * Usage in this file: setKey() -> setMem() -> readMem()/writeMem() -> close().
 * All members carry default initializers so that an instance which never went
 * through setMem() holds "nothing attached" values: in particular m_shmid
 * starts at -1 because 0 is a valid segment identifier, and close() on an
 * unset object must not target an unrelated segment.
 */
class CSHM {
	private :
		int max_size = 0;                 // segment size in bytes, stored by setMem()
		int m_shmid = -1;                 // identifier returned by shmget(); -1 = none yet
		key_t m_key = 0;                  // IPC key stored by setKey()
		char *m_shared_memory = nullptr;  // address returned by shmat(); nullptr = not attached
	public :
		char read_data[50000] = {};       // local copy of the segment, filled by readMem()
		int getShmId();                   // NOTE(review): no definition is visible in this part of the file
		void setKey(key_t key);
		void setMem(int permission, int r_size);
		void writeMem(std::string str);
		void readMem();
		void close();
};
void CSHM::setKey(key_t key) {
    m_key = key;
}
/**
 * Creates (or opens) the segment for the stored key and attaches it.
 *
 * permission : mode bits OR-ed with IPC_CREAT in the shmget() flags.
 * r_size     : requested segment size in bytes; remembered in max_size.
 *
 * Any failure is reported through perror() and ends the process with
 * exit(1). If the attach fails, the segment is marked for removal first.
 */
void CSHM::setMem(int permission, int r_size) {
	max_size = r_size;

	m_shmid = shmget(m_key, max_size, IPC_CREAT | permission);
	if (m_shmid < 0) {
		perror("shmget failed ");
		exit(1);
	}

	void *attached = shmat(m_shmid, NULL, 0);
	m_shared_memory = (char *)attached;
	if (m_shared_memory == (char *)-1) {
		perror("shmat failed ");
		shmctl(m_shmid, IPC_RMID, NULL);
		exit(1);
	}
}
/**
 * Copies `str` to the start of the attached segment.
 *
 * At most max_size bytes are written, so an oversized string can no longer
 * run past the end of the segment. When there is room, a terminating '\0' is
 * stored right after the payload so that the tail of a longer, earlier
 * message is not mistaken for part of this one (a freshly created segment is
 * zero-filled, so the reader already sees this layout after the first write).
 * Nothing is written if no segment is attached.
 */
void CSHM::writeMem(string str) {
	cout << " C write size: " <<  str.size() << endl;
	if (m_shared_memory == NULL || m_shared_memory == (char *)-1 || max_size <= 0) {
		cerr << "writeMem: no shared memory attached" << endl;
		return;
	}
	const size_t capacity = static_cast<size_t>(max_size);
	const size_t n_bytes = str.size() < capacity ? str.size() : capacity;
	memcpy(m_shared_memory, str.data(), n_bytes);
	if (n_bytes < capacity) {
		m_shared_memory[n_bytes] = '\0';
	}
}
/**
 * Copies the segment contents into read_data and NUL-terminates the copy.
 *
 * The copy length is the segment size clamped to sizeof(read_data) - 1, so a
 * segment larger than the buffer cannot overflow it and read_data is always a
 * valid C string (it is later used to construct a std::string). If no segment
 * is attached, read_data becomes the empty string.
 */
void CSHM::readMem() {
	const size_t capacity = sizeof(read_data) - 1;  // keep one byte for '\0'
	size_t n_bytes = max_size > 0 ? static_cast<size_t>(max_size) : 0;
	if (n_bytes > capacity) {
		n_bytes = capacity;
	}
	if (m_shared_memory == NULL || m_shared_memory == (char *)-1) {
		n_bytes = 0;
	} else {
		memcpy(read_data, m_shared_memory, n_bytes);
	}
	read_data[n_bytes] = '\0';
}
void CSHM::close() {
	sleep(3); 
	void* shmdt(void *m_shmid);
	shmctl(m_shmid , IPC_RMID, NULL); 
}

// Initial value that sem_init() stores in the semaphore via SETVAL.
#define SEM_RESOURCE_MAX 1
// Brace initializers for `struct sembuf` (sem_wait()/sem_post() use the _2 pair).
// Presumably the fields are {sem_num, sem_op, sem_flg} as laid out on Linux - TODO confirm
// for other platforms. The _1 pair (SEM_UNDO) is not used in the visible part of this file.
#define SEM_LOCK_1 {0, -1, SEM_UNDO} 
#define SEM_UNLOCK_1 {0, 1, SEM_UNDO}
// IPC_NOWAIT variants: semop() returns immediately instead of blocking.
#define SEM_LOCK_2 {0, -1, IPC_NOWAIT} 
#define SEM_UNLOCK_2 {0, 1, IPC_NOWAIT}
// Argument type for the fourth parameter of semctl(); sem_init() fills `val` for SETVAL.
union semun {
    int val;             
    struct semid_ds *buf;     
    unsigned short int *array; 
	struct seminfo  *__buf;
};

/**
 * Creates a new single-semaphore set for `key` and sets it to SEM_RESOURCE_MAX.
 *
 * semid : receives the identifier returned by semget().
 * key   : System V IPC key; creation is exclusive (IPC_CREAT|IPC_EXCL), so an
 *         already existing set for this key is treated as an error.
 *
 * Returns 0 on success. Any failure is reported with perror() and ends the
 * process with exit(1). The SETVAL step is checked as well: if it failed
 * silently the semaphore would stay at 0 and every later wait would never
 * succeed. A set whose initialisation failed is removed before exiting.
 */
static inline int
sem_init(int *semid, key_t key) {
    if((*semid = semget(key, 1, IPC_CREAT|IPC_EXCL|0606)) == -1) {
        perror("semget failed:");
        exit(1);
    }
    union semun semopts;
    semopts.val = SEM_RESOURCE_MAX;
    if (semctl(*semid, 0, SETVAL, semopts) == -1) {
        perror("semctl SETVAL failed:");
        semctl(*semid, 0, IPC_RMID, 0);
        exit(1);
    }
    return 0;
}

/**
 * Decrements the semaphore, polling until the decrement succeeds.
 *
 * The operation uses IPC_NOWAIT, so semop() fails with EAGAIN while the
 * semaphore is 0 and this function keeps retrying (busy-wait). Only EAGAIN
 * and EINTR are retried; any other error (for example the set having been
 * removed) can never resolve by itself, so it is reported and the process
 * exits instead of spinning forever.
 */
static inline void
sem_wait(int *semid) {
	struct sembuf sem_lock = SEM_LOCK_2;
	while (semop(*semid, &sem_lock, 1) == -1) {
		if (errno != EAGAIN && errno != EINTR) {
			perror("semop (wait) failed:");
			exit(1);
		}
	}
}

/**
 * Increments the semaphore, retrying on transient failures.
 *
 * Only EAGAIN and EINTR are retried. Any other semop() error is permanent
 * (for example the set was removed), so it is reported and the process exits
 * instead of looping forever.
 */
static inline void
sem_post(int *semid) {
	struct sembuf sem_lock = SEM_UNLOCK_2;
	while (semop(*semid, &sem_lock, 1) == -1) {
		if (errno != EAGAIN && errno != EINTR) {
			perror("semop (post) failed:");
			exit(1);
		}
	}
}

// Removes the semaphore set identified by *semid (IPC_RMID); the result of
// semctl() is not inspected.
static inline void
sem_destroy(int *semid) {
	const int set_id = *semid;
	semctl(set_id, 0, IPC_RMID, 0);
}


//global variable section 1
// Segment that main() reads each iteration; pof_output_reading() parses its contents.
CSHM pof_shared_memory_1;
// Segment that pof_update_writing() writes the chosen action to.
CSHM pof_shared_memory_2;
// Semaphore ids created in main() by sem_init(): main() waits on sem_1 before
// reading and posts sem_2 after writing.
int sem_1;
int sem_2;
// IPC keys, computed in main() from the key_offset and thread_order arguments.
key_t key_1;
key_t key_2;
// Octal mode bits passed to CSHM::setMem(): 0602 = owner read/write, others write;
// 0604 = owner read/write, others read.
int permission_1 = 0602;
int permission_2 = 0604;
// Second dimension of ans, test_case_ccon_w and normt_eff below.
const int max_intest_size = 5000000 + 5000000;
// Overwritten by pof_output_reading(); approxloss() takes a separate early-return
// path while the value is below -10000. Reset by initialize_variables().
int default_action = -1000000;
// Used as a flag (int): approxloss() clears it when an action was sampled;
// pof_update_writing() sends it as the second output line.
int default_action_using = true;
// Set by approxloss() on its early-return path; cleared by initialize_variables().
bool hazard_check = false;
// Candidate values for `th` / `xi`; main() picks one of each from thread_order.
const int n_th = 10;
double th_array[n_th] = {0.00001, 0.0001,0.001,0.01,0.1,0.49,0.499,0.4999,0.49999,1};
const int n_xi = 1;
double xi_array[n_xi] = {0.01}; //#####
double th = -1.0;
double xi = -1.0;
// Last value parsed by pof_output_reading(); approxloss() only considers samples
// whose assigned class equals it.
int min_output_class = -1;

//global variable section 2
// Number of rows per network in ccon_w and classassign::assign.
const int n_sample_of_action = 252;
const int n_NN = 1;
// Number of n_test_case entries read per network by pof_output_reading().
int n_rank = 2;
// n_class is overwritten by pof_output_reading() and bounds the class loops.
int n_test_case[n_NN][2], n_class = 2;
// ans_sum is zeroed by initialize_variables(); ans is not referenced in the visible part of this file.
int ans[n_NN][max_intest_size], ans_sum[n_NN][2];
// beta and lambda are not referenced in the visible part of this file.
double beta;
double lambda;
// Filled by pof_output_reading(); approxloss() uses them as per-axis centre and
// scale of the weight it assigns to each grid cell.
double mu[2] = {}, sigma[2]={};
// ccon_w: per-sample class scores parsed by pof_output_reading() (clamped to
// [-100, 100], then halved). test_case_ccon_w is not referenced in the visible part of this file.
double ccon_w[n_NN][n_sample_of_action][5], test_case_ccon_w[n_NN][max_intest_size][5];
// Zeroed by initialize_variables(); no other use is visible in this file.
int normt_eff[n_NN][max_intest_size][5][2];
double main_approxloss;
// Action chosen by approxloss(); first output line of pof_update_writing().
int act = -1;
// First input line; main() leaves its loop when it is -1.
int mode = -1;
// Both are set to half of arrival_intest_size by pof_output_reading().
int intest_unit_number_1 = -1;
int intest_unit_number_2 = -1;
int arrival_intest_size = 0;


void signal_callback_handler(int signum) {
	cout << "Caught signal " << signum << endl;
	sem_destroy(&sem_1);
	sem_destroy(&sem_2);
	cout << "DESTROY S1 & S2" << endl;
	pof_shared_memory_1.close();
	pof_shared_memory_2.close();
	cout << "DESTROY SHM1 & SHM2" << endl;
	exit(signum);
}


void pof_output_reading() {
	string shared_string(pof_shared_memory_1.read_data);
	string temp;
	stringstream pof_output(shared_string);
	getline(pof_output, temp);
	mode = stoi(temp);
	getline(pof_output, temp);
	arrival_intest_size = stoi(temp);
	intest_unit_number_1 = int(arrival_intest_size*0.5);
	intest_unit_number_2 = int(arrival_intest_size*0.5);
	getline(pof_output, temp);
	default_action = stoi(temp);
	getline(pof_output, temp);
	mu[0] = stod(temp);
	getline(pof_output, temp);
	mu[1] = stod(temp);
	getline(pof_output, temp);
	sigma[0] = stod(temp);
	getline(pof_output, temp);
	sigma[1] = stod(temp);
	getline(pof_output, temp);
	n_class = stoi(temp);
	for (int i = 0; i < n_NN; i++) {
		for (int j = 0; j < n_sample_of_action; j++) {
			for (int k = 0; k < n_class; k++) {
				getline(pof_output, temp);
				ccon_w[i][j][k] = max(-100., min(stod(temp), 100.))/2.;
			}
		}
	}
	for (int i = 0; i < n_NN; i++) {
		for (int j = 0; j < n_rank; j++) {
			getline(pof_output, temp);
			n_test_case[i][j] = stoi(temp);
		}
	}
	getline(pof_output, temp);
	min_output_class = stod(temp);
}


void pof_update_writing() {
	string shared_string = "";
	shared_string += to_string(act);
    shared_string += "\n";
	shared_string += to_string(default_action_using);
	shared_string += "\n";
	pof_shared_memory_2.writeMem(shared_string);
}


// Integer table indexed [network][2][5][2]; starts out zeroed.
struct table {
	int results[n_NN][2][5][2] = {};

	// Sets every entry of `results` back to 0.
	void reset() {
		memset(&results, 0, sizeof results);
	}
};
// Holder for a [network][5][2] array of doubles, zero-initialised on construction.
struct posterior {
	double post_prob[n_NN][5][2] = {};
};
// One class index per (network, sample); filled in by calc_normalclass().
struct classassign {
	int assign[n_NN][n_sample_of_action];

	// Sets every assignment back to 0.
	void reset() {
		memset(&assign, 0, sizeof assign);
	}
};

//global variable section 3
// Cleared by initialize_variables(); no other use is visible in this file.
table normaltable;
// Written by calc_normalclass(); main() passes a copy of it to approxloss().
classassign normalclass;


/**
 * For every network and sample, stores in normalclass.assign the index of the
 * largest score among the first n_class entries of ccon_w. Ties keep the
 * lowest index; an entry is left untouched when no score exceeds -5000000.
 */
void calc_normalclass() {
	for (int net = 0; net < n_NN; ++net) {
		for (int sample = 0; sample < n_sample_of_action; ++sample) {
			const double *scores = ccon_w[net][sample];
			double best_score = -5000000;
			for (int cls = 0; cls < n_class; ++cls) {
				if (!(scores[cls] > best_score)) {
					continue;
				}
				best_score = scores[cls];
				normalclass.assign[net][sample] = cls;
			}
		}
	}
}

/**
 * Samples an action from a 376 x 376 grid, restricted to the cells whose
 * reference sample is assigned the class `min_output_class`.
 *
 * Grid cell (i, j) stands for the point (i*0.016 - 3, j*0.016 - 3). Its
 * weight is exp(-dx^2/sigma[0]^2 - dy^2/sigma[1]^2) with dx, dy measured from
 * mu[0], mu[1]; weights are shifted by the largest exponent before exp() for
 * numerical stability. A cell is encoded as the action i*376 + j.
 *
 * p       : unused; kept so existing callers compile unchanged.
 * c       : class assignment per sample (row 0 is used).
 * is_main : when non-zero, the result is also published to the globals
 *           `act` and `default_action_using`.
 *
 * Returns the number of eligible grid cells. While default_action is below
 * -10000 nothing is sampled: hazard_check is set and default_action + 10000
 * is returned (and stored in `act` when is_main is set).
 *
 * Fixes relative to the previous version:
 *  - Both passes now use the same cell -> sample mapping. The weight pass
 *    used `j > 150` where the maximum pass used `j > 250`, so the two passes
 *    looked at different cells; the 250 form clamps j - 125 to [0, 125] and
 *    covers all 126 samples of each half.
 *  - The 2 MB weight table lives on the heap instead of the stack.
 *  - Sampling is skipped (default action kept) when the weight sum is not a
 *    positive finite number, and a rounding-induced miss at the end of the
 *    scan falls back to the last eligible cell instead of silently returning
 *    the default action while reporting that it was not used.
 */
double approxloss(posterior p, classassign c, int is_main = 0) {
	(void)p;
	int tmp = default_action;
	const int index_tmp = 0;
	if (default_action < -10000) {
		hazard_check = true;
		if (is_main) act = tmp + 10000;
		return tmp + 10000;
	}

	const int grid = 376;

	// Sample index that represents grid cell (i, j).
	const auto ref_sample_of = [](int i, int j) -> int {
		return 126 * (i >= 188) + (j > 250 ? 125 : (j > 125 ? j - 125 : 0));
	};
	// Unshifted exponent of the cell weight.
	const auto log_weight = [](int i, int j) -> double {
		return -(i*0.016-3-mu[0])*(i*0.016-3-mu[0])/sigma[0]/sigma[0]
		       -(j*0.016-3-mu[1])*(j*0.016-3-mu[1])/sigma[1]/sigma[1];
	};
	const auto is_eligible = [&c, index_tmp, &ref_sample_of](int i, int j) -> bool {
		return c.assign[index_tmp][ref_sample_of(i, j)] == min_output_class;
	};

	// Pass 1: largest exponent among eligible cells.
	double max_tmp = -1e9;
	for (int i = 0; i < grid; i++) {
		for (int j = 0; j < grid; j++) {
			if (!is_eligible(i, j)) continue;
			const double lw = log_weight(i, j);
			if (!(max_tmp > lw)) max_tmp = lw;
		}
	}

	// Pass 2: shifted weights, their sum, and the number of eligible cells.
	std::vector<double> unnormalized_prob(static_cast<size_t>(grid) * grid, 0.0);
	double sum_unnormalized_prob = 0;
	int cnt = 0;
	int last_eligible = -1;
	for (int i = 0; i < grid; i++) {
		for (int j = 0; j < grid; j++) {
			if (!is_eligible(i, j)) continue;
			const int cell = i * grid + j;
			unnormalized_prob[cell] = exp(log_weight(i, j) - max_tmp);
			sum_unnormalized_prob += unnormalized_prob[cell];
			last_eligible = cell;
			cnt++;
		}
	}

	// Pass 3: draw a cell with probability proportional to its weight.
	if (cnt && std::isfinite(sum_unnormalized_prob) && sum_unnormalized_prob > 0) {
		if (is_main) default_action_using = false;
		std::random_device rd;
		std::uniform_real_distribution<double> distr(0, sum_unnormalized_prob);
		double sample = distr(rd);
		int chosen = last_eligible;
		for (int cell = 0; cell < grid * grid; cell++) {
			sample -= unnormalized_prob[cell];
			if (sample < 0) {
				chosen = cell;
				break;
			}
		}
		tmp = chosen;
	}

	cout << "sampling action: " << tmp << endl;
	if (is_main) act = tmp;
	return cnt;
}

/**
 * Restores the per-iteration global state to its initial values so the next
 * message starts from a clean slate. Always returns 0.
 */
int initialize_variables() {
	// Defaults for the action decision.
	default_action = -1000000;
	default_action_using = true;
	hazard_check = false;

	// Section 2 state.
	act = -1;
	mode = -1;
	memset(normt_eff, 0, sizeof normt_eff);
	memset(ans_sum, 0, sizeof ans_sum);

	// Section 3 state.
	normalclass.reset();
	normaltable.reset();

	return 0;
}


int main(int argc, char** argv)
{
	if (argc != 5) {
        cout << "NUM_ARG ERROR" << endl;
        exit(1);
    }
	string exp_name = argv[1];
	string checkpoint = argv[2];
	int key_offset = stoi(argv[3]);
	int thread_order = stoi(argv[4]);
	string grad_path = "./" + exp_name + "/validation_" + checkpoint + "/";
	cout << grad_path << " " << key_offset << " " << thread_order << endl;
	th = th_array[thread_order % n_th];
	xi = xi_array[thread_order / n_th];
	FILE* ACTCNT = nullptr; 
	signal(SIGINT, signal_callback_handler);
	key_1=(key_offset + thread_order)*1157;
	key_2=(key_offset + thread_order)*1257;
	pof_shared_memory_1.setKey(key_1);
	pof_shared_memory_1.setMem(permission_1, 50000); // u-g-o(rw-w-w) => 602
	pof_shared_memory_2.setKey(key_2);
	pof_shared_memory_2.setMem(permission_2, 100); // u-g-o(rw-r-r) => 604
	sem_init(&sem_1, key_1);
	sem_init(&sem_2, key_2);
	sem_wait(&sem_1);
	sem_wait(&sem_2);
	std::cout << "STAND BY - S1-0 & S2-0 !!!" << endl;
	srand((unsigned int)time(NULL));
	while (true) {
		// std::cout << "OUTPUT WAIT ---" << endl; 
		sem_wait(&sem_1);
		pof_shared_memory_1.readMem();
		pof_output_reading(); // S1-0 & S2-0
		if (mode == -1) break;
		// std::cout << "OUTPUT READING COMPLETE" << endl; 
		calc_normalclass();
		posterior posterior_tmp;
		int actcnt = approxloss(posterior_tmp, normalclass, 1);
		
		pof_update_writing();
		// cout << "UPDATE WRITING COMPLETE" << endl;
		initialize_variables();
		sem_post(&sem_2); // S1-0 & S2-1
	}
	sem_destroy(&sem_1);
	sem_destroy(&sem_2);
	fclose(ACTCNT);
	cout << "DESTROY S1 & S2" << endl;
	pof_shared_memory_1.close();
	pof_shared_memory_2.close();
	cout << "DESTROY SHM1 & SHM2" << endl;
	return 0;
}